In [1]:
import torch
import plotly.express as px
import plotly.io as pio
# Every `.show()` below renders through plotly's "notebook" renderer.
pio.renderers.default = "notebook"
def master_plot_function(
    volume: torch.Tensor,
    gray_min: float | None = None,
    gray_max: float | None = None,
    gray_mid: float | None = None,
):
    """Print where ``volume`` peaks and show the slice through that peak.

    The slice ``volume[z, :, :]`` containing the first maximum (in
    ``torch.nonzero`` order) is displayed with ``plotly.express.imshow`` using
    the viridis colour scale and ``origin="lower"``.

    Args:
        volume: Tensor to inspect; expected to be 3-D since it is indexed as
            ``volume[z, :, :]`` (TODO confirm for other ranks). The slice is
            detached and copied to the CPU before plotting, so CUDA tensors and
            tensors that require grad are accepted.
        gray_min: Lower colour-scale bound (``zmin``); defaults to ``volume.min()``.
        gray_max: Upper colour-scale bound (``zmax``); defaults to ``volume.max()``.
        gray_mid: Accepted so existing calls keep working, but currently not
            used by the plot.

    Returns:
        None. Prints the index/indices of the maximum and shows a figure.
    """
    peak = volume.max()
    gray_min = volume.min().item() if gray_min is None else gray_min
    gray_max = peak.item() if gray_max is None else gray_max

    # One row per voxel that ties the global maximum.
    max_idx = torch.nonzero(volume == peak)
    print("Location(s) of maximum value in result:", max_idx.tolist())

    # torch.nonzero lists hits in row-major order, so this is the smallest
    # leading index holding a maximum (falls back to 0 if nothing matched,
    # e.g. when the maximum is NaN).
    z_max = max_idx[0, 0].item() if len(max_idx) > 0 else 0

    # A CUDA tensor or one that requires grad cannot be converted to NumPy
    # implicitly, so give plotly a detached CPU array.
    slice_2d = volume[z_max, :, :].detach().cpu().numpy()
    px.imshow(
        slice_2d,
        origin="lower",
        color_continuous_scale="viridis",
        zmin=gray_min,
        zmax=gray_max,
        title=f"Slice {z_max} with threshold zmin={gray_min:.1f}",
    ).show()
from torch_affine_utils.transforms_3d import Rx, Ry, Rz, T
from torch_grid_utils import dft_center
from torch_transform_image import affine_transform_image_3d
In [2]:
VOLUME_SHAPE = (28, 28, 28)


def _single_voxel(shape, index, dtype=torch.float32):
    """Return a zero tensor of ``shape``/``dtype`` with a single 1 at ``index``."""
    volume = torch.zeros(shape, dtype=dtype)
    volume[index] = 1
    return volume


# Test volume: one bright voxel at (z=14, y=7, x=14).
image = _single_voxel(VOLUME_SHAPE, (14, 7, 14))
# Reference voxel at (14, 14, 14), the middle of each 28-long axis; it is added
# at 10% intensity so it shows up faintly next to the image's peak.
center_dot = _single_voxel(VOLUME_SHAPE, (14, 14, 14))
master_plot_function(image + center_dot * 0.1, gray_mid=0)
Location(s) of maximum value in result: [[14, 7, 14]]
In [3]:
center = dft_center(image.shape, fftshift=True, rfft=False)
rotate_zyx = [90, 0, 0]  # angles passed to Rz, Ry, Rx (a quarter turn about z here)
shifts_zyx = [0, 0, 5]

# Forward transform on homogeneous column vectors, read right to left:
# move the centre to the origin, shift by `shifts_zyx`, rotate, move back.
# In other words: shift, then rotate about the volume centre.
forward_matrix = (
    T(center)
    @ Rz(rotate_zyx[0], zyx=True)
    @ Ry(rotate_zyx[1], zyx=True)
    @ Rx(rotate_zyx[2], zyx=True)
    @ T(shifts_zyx)
    @ T(-center)
)

# Where should the bright voxel of `image` (z=14, y=7, x=14) end up?
source_point = torch.tensor([14, 7, 14, 1]).float()
target_point = forward_matrix @ source_point
print(target_point)
# we expect to end up with z=14, y=19, x=21
assert torch.allclose(
    target_point,
    torch.tensor([14.0, 19.0, 21.0, 1.0], dtype=target_point.dtype),
    atol=1e-4,
), target_point

# affine_transform_image_3d receives the *inverse* matrix. As the recorded
# output shows, this moves image content to where `forward_matrix` sends it.
# Presumably the function samples the input at `matrix @ output_coord`;
# TODO confirm in its documentation.
inverse_matrix = torch.linalg.inv(forward_matrix)
transformed = affine_transform_image_3d(
    image, inverse_matrix, interpolation='trilinear', zyx_matrices=True
)
master_plot_function(transformed + center_dot * 0.1, gray_mid=0)

# The resampled peak must land on the predicted voxel.
expected_peak = target_point[:3].round().long().tolist()
assert torch.nonzero(transformed == transformed.max()).tolist() == [expected_peak]
tensor([14., 19., 21., 1.])
Location(s) of maximum value in result: [[14, 19, 21]]
In [ ]: